# Python Version 2.7

# Author: Yahui Zhang, Jixiang Yang
# 2022.1.20

import math
from scipy.interpolate import griddata
from scipy import interpolate
from matplotlib import colors, cm, pyplot as plt
import matplotlib.patches as patches
import numpy as np
from numpy import linalg as LA
import csv
import matplotlib
import os
import shutil
import time
from scipy import signal

# n-layer graphene aligned with hBN. For trilayer graphene (TLG), use n_layer=3.
# Calculation of the optical conductivity.

# Hopping parameters for ABC TLG itself (all energies in meV).

# Parameter Group 6:
# path_head is the root directory of every data/image path built in this file.
path_head = "para6"
# t, t3 and t4 carry a factor sqrt(3)/2; GBN.__init__ prints them with that
# factor divided out again.
t = -3250.8*np.sqrt(3)/2  # in unit of meV
t3 = 293*np.sqrt(3)/2  # trigonal warping (alternative value noted by the authors: 380)
t4 = -140*np.sqrt(3)/2
gamma1 = 380
gamma2 = -8.3

# Number of k-points in the BZ: N_k*N_k (default grid size of the save_*/cal_Ef methods)
N_k = 250

# Energy range of the spectrum: w_low to w_low+w_range-1
# Number of points: N_w
N_w = 480
w_low = 2.0
w_range = 120.0

# D_min: presumably the first value of a displacement-field sweep with N_D
# points -- the sweep itself is not in this part of the file; TODO confirm.
D_min = -100
N_D = 201

# Moire superlattice potential parameters

VN = -1493  # nitride on-site energy
VB = 3332  # boron on-site energy
tB = 144  # moire hopping between graphene and boron;
tN = 97  # moire hopping between graphene and nitride;

# Second-order energy scales t^2/V built from the values above.
FB = tB*tB/VB
FN = tN*tN/VN

# Check [NEARLY FLAT CHERN BANDS IN MOIRE SUPERLATTICES PHYSICAL REVIEW B 99, 075127 (2019)] for details
n_layer = 3
# Overall scale factor applied to the moire coupling matrices T1..T6
# (see GBN.initialize_Vm).
lambda_M = 1

# Amplitudes C* and phases phi_* (degrees converted to radians) used to build
# the 2x2 moire coupling matrices in GBN.__init__.
C0 = -10.13
phi_0 = 86.53/180*np.pi

Cz = -9.01
phi_z = 8.43/180*np.pi

Cab = 11.34
phi_ab = 19.60/180*np.pi


def get_P_path(D, path_head="para1", N_tran=4):
    """Return the three .npy paths used to cache N_tran-transition results.

    D is formatted as a float (``str(D*1.0)``) inside the file names.  All
    files live in ``<path_head>/data/P<N_tran>/``.

    Returns:
        (path_pplus, path_pminus, path_gap)
    """
    tag = "P"+str(N_tran)
    d_suffix = '_list_D_'+str(D*1.0)+'.npy'

    def in_folder(file_name):
        # Every cache file of this group shares the same directory.
        return os.path.join(path_head, "data", tag, file_name)

    return (in_folder(tag+'_plus'+d_suffix),
            in_folder(tag+'_minus'+d_suffix),
            in_folder("gap"+str(N_tran)+d_suffix))


def get_P_n_path(D, n=2, path_head="para5"):
    """Return the four .npy paths used to cache the 2*n-band results.

    ``n`` is truncated to an integer and D is formatted as a float inside the
    file names.  All files live in ``<path_head>/data/P_n/``.

    Returns:
        (path_pplus, path_pminus, path_gap, path_e)
    """
    n_tag = str(int(n))
    d_suffix = '_list_D_'+str(D*1.0)+'.npy'

    def in_folder(file_name):
        return os.path.join(path_head, "data", "P_n", file_name)

    stems = ('P_'+n_tag+'_plus', 'P_'+n_tag+'_minus',
             'gap_'+n_tag, 'E_'+n_tag)
    return tuple(in_folder(stem+d_suffix) for stem in stems)


def is_block_diagonal(M):
    """Check whether a square matrix is block diagonal with 6x6 blocks.

    The matrix is cut into 6x6 tiles; every off-diagonal tile must be exactly
    zero.  The (row, column) tile indices of each offending tile are printed.

    Args:
        M: square 2-D array-like whose size is a multiple of 6.

    Returns:
        True if all off-diagonal 6x6 tiles vanish, False otherwise.

    Raises:
        AssertionError: if M is not square or its size is not a multiple of 6.
    """
    M = np.array(M)
    L1 = len(M)
    L2 = len(M[0])
    assert L1 == L2
    assert L1 % 6 == 0
    nn = L1 // 6
    flag = True
    for i in range(nn):
        for j in range(nn):
            if i == j:
                continue
            h = M[i*6:i*6+6, j*6:j*6+6]
            if np.any(h != 0):
                # Formatted so the output is "i j" on both Python 2 and 3.
                print("%s %s" % (i, j))
                flag = False

    return flag


class GBN(object):  # graphene on boron Nitride
    def __init__(self):
        """Set up lattice geometry and the six moire coupling matrices.

        Reads the module-level parameters (n_layer, t, gamma1, gamma2, t3, t4,
        C0, Cz, Cab, phi_0, phi_z, phi_ab, lambda_M) and prints a short
        summary.  The displacement field starts at 0; use set_D to change it.
        """
        print("start calculation!")
        # Single formatted string: a two-argument print(...) call would show a
        # tuple repr under Python 2.
        print('number of layer: %s' % n_layer)

        print('parameters (t, gamma1, gamma2, t3, t4):')
        # t, t3, t4 are stored with a factor sqrt(3)/2; print them without it.
        print("%s %s %s %s %s" % (
            t*2/np.sqrt(3), gamma1, gamma2, t3*2/np.sqrt(3), t4*2/np.sqrt(3)))

        self.D = 0  # displacement field
        self.U = self.D  # voltage difference

        # cut off; free electron approximation, k scatters to k + m G_1 + n G_2,
        # truncated to |m|,|n| <= M/2
        self.M = 5

        # Moire superlattice constant in units of the graphene lattice
        # constant a_0 (a_0 = 2.46 angstrom, a_M = 14.6 nm)
        self.aM = 58.8
        self.kappa = 1.

        # The first reciprocal vector
        self.Gx1 = 0
        self.Gy1 = -4*math.pi/(math.sqrt(3)*self.aM)

        # The second reciprocal vector
        self.Gx2 = 2*math.pi/self.aM
        self.Gy2 = -2*math.pi/(math.sqrt(3)*self.aM)

        self.alpha = 1.0

        # parameters from Jeil Jung's paper
        # Diagonal entries: C0 term plus/minus Cz term, with either sign of
        # the phases.
        aa_pos = C0*np.exp(1j*phi_0)+Cz*np.exp(1j*phi_z)
        bb_pos = C0*np.exp(1j*phi_0)-Cz*np.exp(1j*phi_z)
        aa_neg = C0*np.exp(-1j*phi_0)+Cz*np.exp(-1j*phi_z)
        bb_neg = C0*np.exp(-1j*phi_0)-Cz*np.exp(-1j*phi_z)

        self.T1 = np.array(
            [[aa_pos, Cab*np.exp(1j*(2*math.pi/3-phi_ab))],
             [Cab*np.exp(1j*(2*math.pi/3-phi_ab)), bb_pos]], dtype=complex)
        self.T2 = np.array(
            [[aa_neg, Cab*np.exp(1j*phi_ab)],
             [Cab*np.exp(1j*(2*math.pi/3+phi_ab)), bb_neg]], dtype=complex)
        self.T3 = np.array(
            [[aa_pos, Cab*np.exp(-1j*phi_ab)],
             [Cab*np.exp(-1j*(2*math.pi/3+phi_ab)), bb_pos]], dtype=complex)

        # T4..T6 are the Hermitian conjugates of T1..T3; get_H places them on
        # the opposite reciprocal-lattice shifts.
        self.T4 = self.T1.conj().T
        self.T5 = self.T2.conj().T
        self.T6 = self.T3.conj().T

        self.initialize_Vm(lambda_M)

    def initialize_Vm(self, lambda_M):
        """Rescale the six moire coupling matrices T1..T6 by ``lambda_M``."""
        for attr in ("T1", "T2", "T3", "T4", "T5", "T6"):
            setattr(self, attr, getattr(self, attr)*lambda_M)

    def set_D(self, D):
        """Store D (converted to float) as both self.D and self.U."""
        as_float = D*1.0
        self.D = self.U = as_float

    def get_h(self, kx, ky):  # Hamiltonian for ABC trilayer itself; B_1 dimerizes with A_2
        """Return the (2*n_layer x 2*n_layer) Hamiltonian of the bare stack.

        Rows/columns 2*i and 2*i+1 are the two sites of layer i.  The on-site
        potential is linear in the layer index and proportional to self.U;
        t couples the two sites of one layer, gamma1/t3/t4 couple adjacent
        layers and gamma2 couples layers i and i+2.
        """
        dim = 2*n_layer
        h = np.zeros((dim, dim), dtype=complex)

        k_minus = kx-1j*ky
        k_plus = kx+1j*ky
        n_mid = (n_layer-1)*0.5

        for layer in range(n_layer):
            a = 2*layer      # first site of this layer
            b = 2*layer+1    # second site of this layer

            # Layer potential, centred on the middle of the stack.
            if n_layer == 1:
                onsite = -self.U
            else:
                onsite = -self.U/(n_layer-1)*(layer-n_mid)
            h[a, a] = onsite
            h[b, b] = onsite

            # Intra-layer hopping.
            h[a, b] = -t*k_minus
            h[b, a] = -t*k_plus

            if layer < n_layer-1:
                a_up = a+2   # first site of the next layer
                b_up = a+3   # second site of the next layer

                h[b, a_up] = gamma1
                h[a_up, b] = gamma1

                h[a, b_up] = -t3*k_plus
                h[b_up, a] = -t3*k_minus

                h[a, a_up] = -t4*k_minus
                h[a_up, a] = -t4*k_plus

                h[b, b_up] = -t4*k_minus
                h[b_up, b] = -t4*k_plus

            if layer < n_layer-2:
                # Coupling between layer i and layer i+2.
                h[a, a+5] = gamma2
                h[a+5, a] = gamma2

        return h

    def get_h_dx(self, kx, ky):  # B_1 dimerize with A_2
        """Return the kx-derivative of get_h.

        Every k-dependent term of get_h is linear in kx, so the result does
        not depend on kx or ky; the arguments are accepted but not used.
        """
        dim = 2*n_layer
        h = np.zeros((dim, dim), dtype=complex)

        for layer in range(n_layer):
            a, b = 2*layer, 2*layer+1

            h[a, b] = -t
            h[b, a] = -t

            if layer >= n_layer-1:
                # Top layer: no layer above to couple to.
                continue

            a_up, b_up = a+2, a+3

            h[a, b_up] = -t3
            h[b_up, a] = -t3

            h[a, a_up] = -t4
            h[a_up, a] = -t4

            h[b, b_up] = -t4
            h[b_up, b] = -t4

        return h

    def get_h_dy(self, kx, ky):  # B_1 dimerize with A_2
        """Return the ky-derivative of get_h.

        Every k-dependent term of get_h is linear in ky, so the result does
        not depend on kx or ky; the arguments are accepted but not used.
        """
        dim = 2*n_layer
        h = np.zeros((dim, dim), dtype=complex)

        for layer in range(n_layer):
            a, b = 2*layer, 2*layer+1

            # d/dky of -t*(kx -/+ 1j*ky)
            h[a, b] = 1j*t
            h[b, a] = -1j*t

            if layer >= n_layer-1:
                continue

            a_up, b_up = a+2, a+3

            # d/dky of -t3*(kx +/- 1j*ky)
            h[a, b_up] = -t3*(1j)
            h[b_up, a] = -t3*(-1j)

            # d/dky of -t4*(kx -/+ 1j*ky)
            h[a, a_up] = -t4*(-1j)
            h[a_up, a] = -t4*(1j)

            h[b, b_up] = -t4*(-1j)
            h[b_up, b] = -t4*(1j)

        return h

    def get_H(self, kx, ky):  # Moire Hamiltonian
        """Return the moire Hamiltonian at (kx, ky).

        The basis consists of self.M*self.M momentum shifts
        k + (m-M//2)*G1 + (n-M//2)*G2, each carrying a 2*n_layer block from
        get_h.  The 2x2 matrices T1..T6 couple the first two rows of block
        (m, n) to the first two columns of its six neighbouring blocks;
        neighbour indices are wrapped with self.PBE.  A message is printed if
        the result is not Hermitian to within 1e-9.
        """
        size = self.M
        medium_M = size // 2   # explicit floor division (5 // 2 == 2)
        nn = 2*n_layer         # h(k) is nn*nn

        H = np.zeros((nn*size*size, nn*size*size), dtype=complex)

        for i in range(size*size):
            m, n = divmod(i, size)

            kkx = kx+(m-medium_M)*self.Gx1+(n-medium_M)*self.Gx2
            kky = ky+(m-medium_M)*self.Gy1+(n-medium_M)*self.Gy2

            row = nn*i
            H[row:row+nn, row:row+nn] = self.get_h(kkx, kky)

            # The following are moire potential on the top graphene layer; we
            # assume that the hBN on the top is aligned.
            # (neighbour m, neighbour n, coupling), applied in this order.
            couplings = (
                (self.PBE(m+1), n, self.T1),
                (m, self.PBE(n+1), self.T2),
                (self.PBE(m-1), self.PBE(n+1), self.T3),
                (self.PBE(m-1), n, self.T4),
                (m, self.PBE(n-1), self.T5),
                (self.PBE(m+1), self.PBE(n-1), self.T6),
            )
            for m_to, n_to, coupling in couplings:
                col = nn*(m_to*size+n_to)
                H[row:row+2, col:col+2] = coupling

        asymmetry = np.sum(np.abs(H-H.conj().T))
        if asymmetry > 0.000000001:
            print("not hermitian %s" % asymmetry)

        return H

    def get_H0(self, kx, ky):  # No Moire interaction - Hamiltonian
        """Return the Hamiltonian at (kx, ky) without moire coupling.

        Same basis as get_H (self.M*self.M momentum shifts, one 2*n_layer
        block each) but only the diagonal get_h blocks are filled.  A message
        is printed if the result is not Hermitian to within 1e-9.
        """
        size = self.M
        medium_M = size // 2   # explicit floor division (5 // 2 == 2)
        nn = 2*n_layer

        H = np.zeros((nn*size*size, nn*size*size), dtype=complex)

        for i in range(size*size):
            m, n = divmod(i, size)

            kkx = kx+(m-medium_M)*self.Gx1+(n-medium_M)*self.Gx2
            kky = ky+(m-medium_M)*self.Gy1+(n-medium_M)*self.Gy2

            row = nn*i
            H[row:row+nn, row:row+nn] = self.get_h(kkx, kky)

        asymmetry = np.sum(np.abs(H-H.conj().T))
        if asymmetry > 0.000000001:
            print("not hermitian %s" % asymmetry)

        return H

    # calculate the derivative \partial H(k) / \partial k_x; the purpose is to calculate the Berry curvature
    def get_H_dx(self, kx, ky):
        """Return the kx-derivative of the Hamiltonian in the get_H basis.

        Only the diagonal blocks (from get_h_dx) are filled; the moire
        couplings T1..T6 are not added here.  A message is printed if the
        result is not Hermitian to within 1e-9.
        """
        size = self.M
        medium_M = size // 2   # explicit floor division (5 // 2 == 2)
        nn = 2*n_layer

        H = np.zeros((nn*size*size, nn*size*size), dtype=complex)

        for i in range(size*size):
            m, n = divmod(i, size)

            kkx = kx+(m-medium_M)*self.Gx1+(n-medium_M)*self.Gx2
            kky = ky+(m-medium_M)*self.Gy1+(n-medium_M)*self.Gy2

            row = nn*i
            H[row:row+nn, row:row+nn] = self.get_h_dx(kkx, kky)

        asymmetry = np.sum(np.abs(H-H.conj().T))
        if asymmetry > 0.000000001:
            print("not hermitian %s" % asymmetry)

        return H

    def get_H_dy(self, kx, ky):
        """Return the ky-derivative of the Hamiltonian in the get_H basis.

        Only the diagonal blocks (from get_h_dy) are filled; the moire
        couplings T1..T6 are not added here.  A message is printed if the
        result is not Hermitian to within 1e-9.
        """
        size = self.M
        medium_M = size // 2   # explicit floor division (5 // 2 == 2)
        nn = 2*n_layer

        H = np.zeros((nn*size*size, nn*size*size), dtype=complex)

        for i in range(size*size):
            m, n = divmod(i, size)

            kkx = kx+(m-medium_M)*self.Gx1+(n-medium_M)*self.Gx2
            kky = ky+(m-medium_M)*self.Gy1+(n-medium_M)*self.Gy2

            row = nn*i
            H[row:row+nn, row:row+nn] = self.get_h_dy(kkx, kky)

        asymmetry = np.sum(np.abs(H-H.conj().T))
        if asymmetry > 0.000000001:
            print("not hermitian %s" % asymmetry)

        return H

    def PBE(self, i):
        """Wrap index ``i`` into the range [0, self.M) (periodic boundary).

        Uses the builtin ``int``: the ``np.int`` alias was deprecated in
        NumPy 1.20 and removed in 1.24.
        """
        return int((i+self.M) % self.M)

    def dispersion_plot(self, band_num=2, with_moire=True):  # Plot the dispersion
        """Plot the 2*band_num central bands along a closed k-path and save it.

        The path has four straight legs (100, 100, 150 and 100 points, each
        leg excluding its end point) starting and ending at k = (0, 0).  Band
        width and gaps of the plotted bands are printed, and the figure is
        written to <path_head>/image/dispersion/D_<D>.png.

        Args:
            band_num: number of bands kept on each side of the middle of the
                spectrum (the printed "full fill gap" uses column
                band_num-2, so band_num >= 2 is expected).
            with_moire: use get_H if True, get_H0 otherwise.

        NOTE: the figure is not closed here.
        """
        mid_index = n_layer*self.M*self.M

        def get_energy_list(start_kx, start_ky, total_kx, total_ky, num_k):
            # Central band energies at num_k equally spaced points of one leg.
            energy_list = np.zeros((num_k, 2*band_num))

            for i in range(num_k):
                kx = start_kx+i*total_kx/num_k
                ky = start_ky+i*total_ky/num_k

                if with_moire:
                    H = self.get_H(kx, ky)
                else:
                    H = self.get_H0(kx, ky)

                energy, states = LA.eigh(H)
                energy = np.sort(np.real(energy))
                energy_list[i, :] = energy[mid_index-band_num:mid_index+band_num]

            return energy_list

        # Corners of the k-path and number of sampled points on each leg.
        corners = [
            (0, 0),
            (-4*math.pi/(3*self.aM), 0),
            (-2*math.pi/(3*self.aM), 2*math.pi/(math.sqrt(3)*self.aM)),
            (4*math.pi/(3*self.aM), 0),
            (0, 0),
        ]
        points_per_leg = [100, 100, 150, 100]

        legs = []
        ticks = [0]   # cumulative point index of every corner
        for start, end, num_k in zip(corners[:-1], corners[1:], points_per_leg):
            legs.append(get_energy_list(
                start[0], start[1], end[0]-start[0], end[1]-start[1], num_k))
            ticks.append(ticks[-1]+num_k)

        energy_list = np.concatenate(legs, axis=0)
        k_list = np.arange(ticks[-1])

        plt.rc('text', usetex=False)
        plt.rc('axes', linewidth=3)
        plt.rc('font', weight='bold')
        plt.rcParams['text.latex.preamble'] = [
            r'\usepackage{sfmath} \boldmath']

        # Column band_num-1 is the highest band below the middle of the
        # spectrum, band_num the lowest one above it.
        lower = energy_list[:, band_num-1]
        upper = energy_list[:, band_num]
        # "%s %s" reproduces the output of the Python 2 print statement.
        print("%s %s" % ("W: ", np.max(lower)-np.min(lower)))
        print("%s %s" % ("gap: ", np.min(upper-lower)))
        print("%s %s" % ("full fill gap: ",
                         np.min(lower)-np.max(energy_list[:, band_num-2])))

        plt.figure()
        plt.plot(k_list, energy_list[:, :], linewidth=3)
        plt.title(path_head)
        plt.xticks(ticks, [
                   r'$\Gamma$', r'$K^{\prime}$', r'$K^{\prime \prime}$', r'$K$', r'$\Gamma$'], fontsize=20)

        plt.yticks(fontsize=20)
        for corner_index in ticks[1:-1]:
            plt.axvline(corner_index)
        plt.xlim(0, ticks[-1])

        plt.savefig(os.path.join(
            path_head, 'image', 'dispersion', 'D_'+str(self.D)+'.png'))

    def save_dispersion(self, band_num=4, with_moire=True):
        """Write the 2*band_num central bands along a k-path to a text file.

        Path (labels from the original comments): M-Gamma-K-K' with
        86, 100 and 100 points; each leg excludes its end point.  Every output
        line holds the running point index followed by the 2*band_num
        energies, tab separated.  The file is
        <path_head>/data/band_structure/Pre_NoMoire4_Dispersion_D_<D>.txt
        (the name does not depend on with_moire).

        Args:
            band_num: number of bands kept on each side of the middle of the
                spectrum; all of them are written (the file has
                1 + 2*band_num columns).
            with_moire: use get_H if True, get_H0 otherwise.
        """
        mid_index = n_layer*self.M*self.M

        def get_energy_list(start_kx, start_ky, total_kx, total_ky, num_k):
            # Central band energies at num_k equally spaced points of one leg.
            energy_list = np.zeros((num_k, 2*band_num))

            for i in range(num_k):
                kx = start_kx+i*total_kx/num_k
                ky = start_ky+i*total_ky/num_k

                if with_moire:
                    H = self.get_H(kx, ky)
                else:
                    H = self.get_H0(kx, ky)

                energy, states = LA.eigh(H)
                energy = np.sort(np.real(energy))
                energy_list[i, :] = energy[mid_index-band_num:mid_index+band_num]

            return energy_list

        corners = [
            # M
            (0, -2*math.pi/(math.sqrt(3)*self.aM)),
            # Gamma
            (0, 0),
            # K
            (-2*math.pi/(3*self.aM), 2*math.pi/(math.sqrt(3)*self.aM)),
            # K'
            (-4*math.pi/(3*self.aM), 4*math.pi/(math.sqrt(3)*self.aM)),
        ]
        points_per_leg = [86, 100, 100]

        legs = []
        for start, end, num_k in zip(corners[:-1], corners[1:], points_per_leg):
            legs.append(get_energy_list(
                start[0], start[1], end[0]-start[0], end[1]-start[1], num_k))

        energy_list = np.concatenate(legs, axis=0)
        k_list = np.arange(sum(points_per_leg))

        filename = os.path.join(
            path_head, "data", "band_structure", "Pre_NoMoire4_Dispersion_D_"+str(self.D)+".txt")
        with open(filename, "w") as f:
            for k_index, band_row in zip(k_list, energy_list):
                # One column per stored band instead of a fixed eight, so any
                # band_num produces a complete row.
                columns = [str(k_index)]
                columns.extend(str(value) for value in band_row)
                f.write("\t".join(columns)+"\n")

    def cal_W_gap(self, with_moire=True):
        """Print widths and gaps of the four central bands on a 50x50 k-grid.

        The grid covers the cell spanned by B1 and B2.  One line is printed:
        D, the widths of the four bands (lowest to highest), then the three
        gaps between consecutive bands (min of the upper band minus max of
        the lower band; negative when the bands overlap).

        Args:
            with_moire: use get_H if True, get_H0 otherwise.

        Returns:
            None.
        """
        B1_x, B1_y = 0, -4*math.pi/(math.sqrt(3)*self.aM)
        B2_x, B2_y = -2*math.pi/self.aM, 2*math.pi/(math.sqrt(3)*self.aM)

        num = 50
        mid_index = n_layer*self.M*self.M
        energy_list = []
        for i in range(num):
            for j in range(num):
                kx = (i*B1_x+j*B2_x)/num
                ky = (i*B1_y+j*B2_y)/num

                if with_moire:
                    H = self.get_H(kx, ky)
                else:
                    H = self.get_H0(kx, ky)

                energy, states = LA.eigh(H)
                energy = np.sort(np.real(energy))
                energy_list.append(energy[mid_index-2:mid_index+2])

        # Rows: k-points; columns: the four central bands, lowest first.
        energy_list = np.array(energy_list)
        band_top = energy_list.max(axis=0)
        band_bottom = energy_list.min(axis=0)

        widths = band_top-band_bottom
        gaps = band_bottom[1:]-band_top[:-1]

        # Space-joined str() values: same text as the Python 2 print statement.
        fields = [self.D]+list(widths)+list(gaps)
        print(" ".join(str(value) for value in fields))
        return

    def cal_bandwidth(self, band_num=2, with_moire=True):
        """Return the width (max - min energy) of one band on a 50x50 k-grid.

        The band is the sorted eigenvalue with index
        n_layer*M*M - 3 + band_num, so the default band_num=2 picks the band
        just below the middle of the spectrum.

        Args:
            band_num: selects the band as described above.
            with_moire: use get_H if True, get_H0 otherwise.
        """
        B1_x, B1_y = 0, -4*math.pi/(math.sqrt(3)*self.aM)
        B2_x, B2_y = -2*math.pi/self.aM, 2*math.pi/(math.sqrt(3)*self.aM)

        num = 50
        band_index = n_layer*self.M*self.M-3+band_num
        build_H = self.get_H if with_moire else self.get_H0

        band_energies = []
        for i in range(num):
            for j in range(num):
                kx = (i*B1_x+j*B2_x)/num
                ky = (i*B1_y+j*B2_y)/num

                eigenvalues, _ = LA.eigh(build_H(kx, ky))
                eigenvalues = np.real(eigenvalues)
                ordered = eigenvalues[eigenvalues.argsort()]
                band_energies.append(ordered[band_index])

        return max(band_energies)-min(band_energies)

    def cal_DOS(self, band_num=2, bz_num=300, divide=100, with_moire=True):
        """Histogram the band just below mid-spectrum and save it as a DOS.

        The band energies are sampled on a bz_num x bz_num k-grid, binned
        into divide+1 bins between the band minimum and maximum, passed
        through gaussian_cov (sigma=0) and the normalizer helper, and written
        as "energy , count" lines to
        <path_head>/data/DOS/DOS_D_<D>sigma_0.csv.

        Args:
            band_num: accepted but not used by this method.
            bz_num: number of k-points along each grid direction.
            divide: number of energy intervals of the histogram.
            with_moire: use get_H if True, get_H0 otherwise.
        """
        # + Gaussian Broadening for supplementary
        B1_x, B1_y = 0, -4*math.pi/(math.sqrt(3)*self.aM)
        B2_x, B2_y = -2*math.pi/self.aM, 2*math.pi/(math.sqrt(3)*self.aM)

        band_index = n_layer*self.M*self.M-1
        build_H = self.get_H if with_moire else self.get_H0

        samples = []
        for i in range(bz_num):
            for j in range(bz_num):
                kx = (i*B1_x+j*B2_x)/bz_num
                ky = (i*B1_y+j*B2_y)/bz_num

                eigenvalues, _ = LA.eigh(build_H(kx, ky))
                eigenvalues = np.real(eigenvalues)
                ordered = eigenvalues[eigenvalues.argsort()]
                samples.append(ordered[band_index])

        bandwidth = max(samples)-min(samples)
        max_e = max(samples)
        min_e = min(samples)
        band_e = np.linspace(min_e, max_e, divide+1, endpoint=True)
        band_count = np.zeros(divide+1, dtype=float)

        for sample in samples:
            # Only the band maximum lands in the last bin (index == divide).
            bin_index = int((sample-min_e)/bandwidth*divide)
            band_count[bin_index] += 1

        normalize = normalizer(band_count)
        band_count = normalize(gaussian_cov(
            band_count, sigma=0, period=bandwidth*1.0/divide))

        file_path = os.path.join(
            path_head, "data", "DOS", "DOS_D_"+str(self.D)+"sigma_0.csv")
        with open(file_path, "w") as out:
            for i in range(divide+1):
                out.write(str(band_e[i])+" , "+str(band_count[i])+'\n')

    def cal_DOS_width(self, band_num=2, bz_num=150, percentage_list=(12, 88),  with_moire=True):
        """Print two percentile energies of the band just below mid-spectrum.

        The band is sampled on a bz_num x bz_num k-grid and sorted; for each
        percentage p the energy at sorted position bz_num*bz_num*p/100 is
        looked up.  The first two looked-up energies are printed on one line.

        Args:
            band_num: accepted but not used by this method.
            bz_num: number of k-points along each grid direction.
            percentage_list: percentages (0..100) to look up; at least two
                entries are required.  The default is an immutable tuple.
            with_moire: use get_H if True, get_H0 otherwise.
        """
        # For Gaussian: the FWHM is corresponding to 12% to 88%
        B1_x, B1_y = 0, -4*math.pi/(math.sqrt(3)*self.aM)
        B2_x, B2_y = -2*math.pi/self.aM, 2*math.pi/(math.sqrt(3)*self.aM)

        num = bz_num
        band_index = n_layer*self.M*self.M-1
        energy_list = []
        for i in range(num):
            for j in range(num):
                kx = (i*B1_x+j*B2_x)/num
                ky = (i*B1_y+j*B2_y)/num

                if with_moire:
                    H = self.get_H(kx, ky)
                else:
                    H = self.get_H0(kx, ky)

                energy, states = LA.eigh(H)
                energy = np.sort(np.real(energy))

                energy_list.append(energy[band_index])

        energy_list.sort()
        last = len(energy_list)-1
        index_list = []

        for p in percentage_list:
            # Floor division keeps the Python 2 result; clamp so that p == 100
            # returns the band maximum instead of running off the list.
            position = min(int(num*num*p // 100), last)
            index_list.append(energy_list[position])
        print("%s %s" % (index_list[0], index_list[1]))

    def cal_Ef(self, band_index=1, N=N_k, percentage=50.0,  with_moire=True):
        """Return the energy below which `percentage` % of one band lies.

        The band energies come from load_P_n(N, n=2); row `band_index` of the
        stored energy table is sorted and the entry at the requested fraction
        of its length is returned (100 % maps to the last entry).

        Args:
            band_index: row of the energy table returned by load_P_n.
            N: grid size forwarded to load_P_n.
            percentage: filling of the band in percent.
            with_moire: accepted but not used by this method.
        """
        # Calculate Fermi level for partial fillings
        _, _, _, E_list = self.load_P_n(N, n=2)
        band_sorted = np.sort(E_list[band_index, :])
        count = len(band_sorted)
        position = int(count*percentage/100)
        if position == count:
            position -= 1
        return band_sorted[position]

    def save_P(self, N=N_k):  # Matrix element for inter-band transition |<c|P|v>|^2
        """Compute and cache |<final|dH/dkx +/- i dH/dky|initial>|^2 and gaps.

        Four transitions among the two bands below (v1, v2) and the two bands
        above (c1, c2) the middle of the spectrum are evaluated on an N x N
        k-grid spanned by G1 and G2.  Row order of the saved arrays:
        0: v1 to c1; 1: v1 to c2; 2: v2 to c1; 3: v2 to c2.  Column index is
        i*N+j.  Nothing is recomputed if all three cache files exist.

        Args:
            N: number of k-points along each grid direction.
        """
        # P4 means only consider 4 transitions
        path_pplus, path_pminus, path_gap = get_P_path(
            D=self.D, path_head=path_head, N_tran=4)

        if os.path.isfile(path_gap) and os.path.isfile(path_pplus) and os.path.isfile(path_pminus):
            print("Use Existed Data!")
            return

        # Make sure the results can be written before starting the long loop.
        out_dir = os.path.dirname(path_gap)
        if out_dir and not os.path.isdir(out_dir):
            os.makedirs(out_dir)

        print("Calculate the Matrix Elements at D="+str(self.D))
        t_P_start = time.time()

        mid_index = n_layer*self.M*self.M
        # (initial band, final band) as indices into the sorted spectrum.
        transitions = (
            (mid_index-1, mid_index),      # v1 to c1
            (mid_index-1, mid_index+1),    # v1 to c2
            (mid_index-2, mid_index),      # v2 to c1
            (mid_index-2, mid_index+1),    # v2 to c2
        )

        P_plus_list = np.zeros((4, N*N))
        P_minus_list = np.zeros((4, N*N))
        gap_list = np.zeros((4, N*N))

        for i in range(N):
            for j in range(N):
                kx = i*self.Gx1/N+j*self.Gx2/N
                ky = i*self.Gy1/N+j*self.Gy2/N

                H = self.get_H(kx, ky)

                energy, states = LA.eigh(H)
                energy = np.real(energy)

                sort_perm = energy.argsort()

                energy = energy[sort_perm]
                states = states[:, sort_perm]

                # Build each derivative matrix once and reuse it for both
                # circular combinations.
                H_dx = self.get_H_dx(kx, ky)
                H_dy = self.get_H_dy(kx, ky)
                HP_plus = H_dx+1j*H_dy
                HP_minus = H_dx-1j*H_dy

                k_index = i*N+j
                for row, (initial, final) in enumerate(transitions):
                    initial_state = states[:, initial]
                    final_conj = np.conj(states[:, final])

                    A = np.dot(final_conj, np.dot(HP_plus, initial_state))
                    P_plus_list[row, k_index] = np.square(np.abs(A))
                    A = np.dot(final_conj, np.dot(HP_minus, initial_state))
                    P_minus_list[row, k_index] = np.square(np.abs(A))

                    gap_list[row, k_index] = energy[final]-energy[initial]

        t_P_end = time.time()
        print("Calculation Time for D="+str(self.D) +
              " is {} seconds.".format(t_P_end-t_P_start))
        np.save(path_pplus, P_plus_list)
        np.save(path_pminus, P_minus_list)
        np.save(path_gap, gap_list)

    def load_P(self, N=N_k):
        """Load the cached four-transition arrays, computing them if missing.

        Args:
            N: grid size passed to save_P when the cache has to be built.

        Returns:
            (P_plus_list, P_minus_list, gap_list) as loaded by np.load.
        """
        path_pplus, path_pminus, path_gap = get_P_path(
            D=self.D, path_head=path_head, N_tran=4)

        cache_complete = (os.path.isfile(path_gap)
                          and os.path.isfile(path_pplus)
                          and os.path.isfile(path_pminus))
        if not cache_complete:
            self.save_P(N)

        return np.load(path_pplus), np.load(path_pminus), np.load(path_gap)

    # calculate interband transition rate at a given frequency omega;
    # N indicates the number of points in BZ
    def save_P_n(self, N=N_k, n=2, with_moire=True):
        """Compute and save matrix elements among the 2*n central bands.

        For every ordered pair (lower band, higher band) among the n bands
        below and n bands above the middle of the spectrum, the squared
        matrix elements of dH/dkx +/- i dH/dky and the energy difference are
        stored for each point of an N x N k-grid spanned by G1 and G2.  Rows
        are ordered by lower band first, then by higher band; the column
        index is i*N+j.  The 2*n band energies are saved as well.  Existing
        files are overwritten.

        Args:
            N: number of k-points along each grid direction.
            n: number of bands kept on each side of the middle of the spectrum.
            with_moire: use get_H if True, get_H0 otherwise.
        """
        # Matrix element for inter-band transition |<c|P|v>|^2

        path_pplus, path_pminus, path_gap, path_e = get_P_n_path(
            D=self.D, n=n, path_head=path_head)

        # Make sure the results can be written before starting the long loop.
        out_dir = os.path.dirname(path_gap)
        if out_dir and not os.path.isdir(out_dir):
            os.makedirs(out_dir)

        print("Calculate the Matrix Elements at D="+str(self.D))
        t_P_start = time.time()

        # 2*n bands -> n*(2*n-1) distinct (lower, higher) pairs
        N_tran = n*(2*n-1)
        P_plus_list = np.zeros((N_tran, N*N))
        P_minus_list = np.zeros((N_tran, N*N))
        gap_list = np.zeros((N_tran, N*N))
        E_list = np.zeros((2*n, N*N))

        mid_index = n_layer*self.M*self.M

        for i in range(N):
            for j in range(N):
                kx = i*self.Gx1/N+j*self.Gx2/N
                ky = i*self.Gy1/N+j*self.Gy2/N

                if with_moire:
                    H = self.get_H(kx, ky)
                else:
                    H = self.get_H0(kx, ky)

                energy, states = LA.eigh(H)
                energy = np.real(energy)

                sort_perm = energy.argsort()

                energy = energy[sort_perm]
                states = states[:, sort_perm]

                state_list = states[:, mid_index-n:mid_index+n]
                energy_list = energy[mid_index-n:mid_index+n]
                k_index = i*N+j
                E_list[:, k_index] = energy_list

                # Build each derivative matrix once and reuse it for both
                # circular combinations.
                H_dx = self.get_H_dx(kx, ky)
                H_dy = self.get_H_dy(kx, ky)
                HP_plus = H_dx+1j*H_dy
                HP_minus = H_dx-1j*H_dy

                tran_index = 0
                for lower in range(2*n-1):
                    for higher in range(lower+1, 2*n):
                        final_conj = np.conj(state_list[:, higher])
                        initial_state = state_list[:, lower]

                        A = np.dot(final_conj, np.dot(HP_plus, initial_state))
                        P_plus_list[tran_index, k_index] = np.square(np.abs(A))
                        A = np.dot(final_conj, np.dot(HP_minus, initial_state))
                        P_minus_list[tran_index, k_index] = np.square(np.abs(A))
                        gap_list[tran_index, k_index] = \
                            energy_list[higher]-energy_list[lower]
                        tran_index += 1

        t_P_end = time.time()
        print("Calculation Time for D="+str(self.D) +
              " is {} seconds.".format(t_P_end-t_P_start))
        np.save(path_pplus, P_plus_list)
        np.save(path_pminus, P_minus_list)
        np.save(path_gap, gap_list)
        np.save(path_e, E_list)

    def load_P_n(self, N, n):
        """Load the cached 2*n-band arrays, computing them if any is missing.

        Args:
            N: grid size passed to save_P_n when the cache has to be built.
            n: number of bands on each side of the middle of the spectrum.

        Returns:
            (P_plus_list, P_minus_list, gap_list, E_list) as loaded by np.load.
        """
        path_pplus, path_pminus, path_gap, path_e = get_P_n_path(
            D=self.D, n=n, path_head=path_head)

        cache_complete = (os.path.isfile(path_gap)
                          and os.path.isfile(path_pplus)
                          and os.path.isfile(path_pminus)
                          and os.path.isfile(path_e))
        if not cache_complete:
            self.save_P_n(N, n)

        return (np.load(path_pplus), np.load(path_pminus),
                np.load(path_gap), np.load(path_e))

    def save_wavefunc(self, N=N_k, n=2, with_moire=True):
        """Save the eigenvectors of the 2*n central bands on an N x N k-grid.

        The saved array has shape (dim_H, 2*n, N*N) where dim_H is the size
        of the Hamiltonian; the last index is i*N+j.  It is written to
        <path_head>/data/wavefunc/Wavefunc_n<n>_D<D>.npy.

        Args:
            N: number of k-points along each grid direction.
            n: number of bands kept on each side of the middle of the spectrum.
            with_moire: use get_H if True, get_H0 otherwise.
        """
        # Wave function
        file_name = "Wavefunc_n"+str(int(n))+"_D"+str(self.D*1.0)+".npy"
        path_wavefunc = os.path.join(path_head, "data", "wavefunc", file_name)

        print("Calculate the Wavefunctions at D="+str(self.D))
        t_P_start = time.time()

        # 2*n bands
        mid_index = n_layer*self.M*self.M
        dim_H = 2*mid_index
        wavefunc_list = np.zeros([dim_H, 2*n, N*N], dtype=complex)
        build_H = self.get_H if with_moire else self.get_H0

        for i in range(N):
            for j in range(N):
                kx = i*self.Gx1/N+j*self.Gx2/N
                ky = i*self.Gy1/N+j*self.Gy2/N

                eigenvalues, eigenvectors = LA.eigh(build_H(kx, ky))
                order = np.real(eigenvalues).argsort()

                # Columns reordered by ascending energy, then the 2*n central
                # ones are kept.
                ordered_vectors = eigenvectors[:, order]
                wavefunc_list[:, :, i*N+j] = \
                    ordered_vectors[:, mid_index-n:mid_index+n]

        t_P_end = time.time()
        print("Calculation Time for D="+str(self.D) +
              " is {} seconds.".format(t_P_end-t_P_start))
        np.save(path_wavefunc, wavefunc_list)

    def inter_band(self, N, omega, t_index=0, N_tran=4, delta=0.1):
        """Return the BZ-averaged inter-band response at frequency omega.

        Each selected transition p contributes, at every k-point,
            (P_plus[p, k] + P_minus[p, k]) * L(gap[p, k] - omega)
        with L(x) = 1 / (pi * (x*x/delta + delta)), a Lorentzian of
        half-width delta used as the joint density of states.  The sum over
        the first N*N k-points is divided by N*N.

        N: the BZ grid is N x N; columns are indexed as i*N + j.
        omega: probe energy.
        t_index: 0 sums transitions 0..3; k in 1..4 keeps transition k-1 only.
        N_tran: accepted for call compatibility; not used by this method.
        delta: Lorentzian broadening.

        The arrays come from load_P(N).  The k-sum is one numpy reduction
        instead of an N*N Python loop, so results agree with the loop form up
        to floating-point summation order.
        """
        assert t_index in range(5)
        P_plus_list, P_minus_list, gap_list = self.load_P(N)

        n_k = N*N
        # Keep the failure mode of explicit indexing: a cache with too few
        # k-points must not be silently truncated by the slices below.
        if min(P_plus_list.shape[1], P_minus_list.shape[1],
               gap_list.shape[1]) < n_k:
            raise IndexError("cached arrays hold fewer than N*N k-points")

        if t_index == 0:
            rows = slice(0, 4)
        else:
            rows = slice(t_index-1, t_index)

        # plus / minus are the two valleys (K and K')
        weight = P_plus_list[rows, :n_k]+P_minus_list[rows, :n_k]
        detuning = gap_list[rows, :n_k]-omega
        lorentz = 1./(detuning*detuning/delta+delta)*(1./np.pi)  # joint DoS

        return np.sum(weight*lorentz)/(N*N)

    def inter_band_n(self, N, omega, n, tran_list, delta=0.1):
        """Return the BZ-averaged response at omega for the listed transitions.

        Same quantity as inter_band, but the data come from load_P_n(N, n)
        and the rows to include are given explicitly.

        N: the BZ grid is N x N; columns are indexed as i*N + j.
        omega: probe energy.
        n: forwarded to load_P_n.
        tran_list: iterable of row indices into the P / gap arrays; repeated
            entries are counted repeatedly, an empty list gives 0.
        delta: Lorentzian broadening of the joint density of states.

        The k-sum is one numpy reduction instead of an N*N Python loop, so
        results agree with the loop form up to floating-point summation order.
        """
        P_plus_list, P_minus_list, gap_list, E_list = self.load_P_n(N, n)

        n_k = N*N
        # A cache with too few k-points must fail loudly, not be truncated.
        if min(P_plus_list.shape[1], P_minus_list.shape[1],
               gap_list.shape[1]) < n_k:
            raise IndexError("cached arrays hold fewer than N*N k-points")

        # Explicit integer dtype so that an empty tran_list is a valid index.
        rows = np.asarray(list(tran_list), dtype=int)

        weight = P_plus_list[rows, :n_k]+P_minus_list[rows, :n_k]
        detuning = gap_list[rows, :n_k]-omega
        lorentz = 1./(detuning*detuning/delta+delta)*(1./np.pi)  # joint DoS

        return np.sum(weight*lorentz)/(N*N)

    def inter_band_partial(self, N, omega, n, band_index=1, Ef=0, delta=0.1):
        """Return the BZ-averaged response at omega for a partially filled band.

        For every k-point the energy E_list[band_index, k] is compared with
        Ef.  Where E <= Ef the transitions [1, 2, 3, 4] are summed, elsewhere
        the transitions [0, 1, 2].  Each term is
            (P_plus + P_minus) / (pi * ((gap - omega)**2 / delta + delta)).
        The total over the first N*N k-points is divided by N*N.

        N: the BZ grid is N x N; columns are indexed as i*N + j.
        omega: probe energy.
        n: forwarded to load_P_n.
        band_index: row of E_list that decides the filling at each k.
        Ef: Fermi level.
        delta: Lorentzian broadening of the joint density of states.

        The k-sum is done with numpy masks instead of an N*N Python loop, so
        results agree with the loop form up to floating-point summation order.
        """
        P_plus_list, P_minus_list, gap_list, E_list = self.load_P_n(N, n)

        tran_list_empty = [0, 1, 2]
        tran_list_full = [1, 2, 3, 4]

        n_k = N*N
        # A cache with too few k-points must fail loudly, not be truncated.
        if min(P_plus_list.shape[1], P_minus_list.shape[1],
               gap_list.shape[1], E_list.shape[1]) < n_k:
            raise IndexError("cached arrays hold fewer than N*N k-points")

        filled = E_list[band_index, :n_k] <= Ef

        def _select(arr, rows, cols):
            # rows: transition indices, cols: boolean mask over k-points
            return arr[rows, :n_k][:, cols]

        F = 0
        for rows, cols in ((tran_list_full, filled),
                           (tran_list_empty, ~filled)):
            weight = (_select(P_plus_list, rows, cols) +
                      _select(P_minus_list, rows, cols))
            detuning = _select(gap_list, rows, cols)-omega
            lorentz = 1./(detuning*detuning/delta+delta)*(1./np.pi)
            F += np.sum(weight*lorentz)

        return F/(N*N)


    def plot_spectrum(self, N, t_index=0, N_tran=4, sigma=0, save_P=False, save_type="png", read_file=True):
        """Plot or export the optical spectrum I(omega) at the current D.

        N: the BZ is divided into N*N points.
        t_index: 0 means all transitions, k > 0 means transition k only.
        N_tran: number of transitions; bounds t_index and tags the file name.
        sigma: width passed to gaussian_cov; 0 disables the smoothing.
        save_P: force recalculation of the matrix elements via save_P(N).
        save_type: "show" opens a window, "png" saves under
            <path_head>/image/spectrum, "txt" under <path_head>/data/spectrum.
        read_file: take the spectrum from the precomputed 2D matrix when the
            matching file exists, instead of recomputing it.

        The precomputed file is chosen per transition, using the names
        written by save_matrix_2D_inter_band: '2D_inter_band.npy' for
        t_index == 0 and '2D_inter_band_I<t_index>.npy' otherwise.  If that
        file is absent the spectrum is computed with inter_band.
        """
        assert t_index in range(N_tran+1)
        assert save_type in ["show", "png", "txt"]

        # force to recalculate matrix elements
        if save_P:
            self.save_P(N)

        w_list = np.arange(N_w)*w_range/N_w+w_low
        I_list = np.zeros(N_w,)
        # Row of the precomputed matrix that belongs to this D.
        D_index = int(self.D)-D_min

        if t_index == 0:
            cache_name = '2D_inter_band.npy'
        else:
            # A single-transition spectrum must not be read from the
            # all-transition matrix.
            cache_name = '2D_inter_band_I'+str(t_index)+'.npy'
        read_filename = os.path.join(path_head, 'data', cache_name)

        if read_file and os.path.isfile(read_filename):
            I_list = np.load(read_filename)[D_index]
        else:
            for i, omega in enumerate(w_list):
                I_list[i] = self.inter_band(
                    N, omega, t_index=t_index, N_tran=N_tran)/omega
        I_list = list(I_list)

        if sigma != 0:
            # Smooth, then rescale to the norm of the unsmoothed spectrum.
            rescale = normalizer(I_list)
            I_list = rescale(gaussian_cov(I_list, sigma=sigma))

        if save_type == "show":
            plt.figure(2)
            plt.plot(w_list, I_list)
            plt.show()

        _filename = 'spectrum_D_'+str(self.D)+"_"+str(N_tran)+"trans"
        if sigma != 0:
            _filename = _filename+"_sigma"+str(sigma)
        if t_index != 0:
            _filename = _filename+'_I'+str(t_index)

        if save_type == "png":
            plt.figure(2)
            plt.plot(w_list, I_list)
            filename = os.path.join(
                path_head, 'image', 'spectrum', _filename+'.png')
            plt.savefig(filename)
            # Figure 2 is reused by every call; close it so the next
            # spectrum is not drawn on top of this one.
            plt.close()

        if save_type == "txt":
            filename = os.path.join(
                path_head, 'data', 'spectrum', _filename+'.txt')
            with open(filename, "w") as out:
                for i in range(N_w):
                    out.write(str(w_list[i])+" "+str(I_list[i])+"\n")

    def plot_spectrum_n(self, N, sigma=0, save_P=False, save_type="png", read_file=True):
        """Plot or export the spectrum built from the n = 2 transition data.

        N: the BZ is divided into N*N points.
        sigma: width passed to gaussian_cov; 0 disables the smoothing.
        save_P: force recalculation via save_P_n(N, n, with_moire=True).
        save_type: "show" opens a window, "png" saves under
            <path_head>/image/spectrum, "txt" under <path_head>/data/spectrum.
        read_file: take the spectrum from '2D_inter_band_2.npy' when that
            file exists, instead of recomputing it.

        When recomputing, the transitions in tran_list_fill are summed with
        inter_band_n.
        """
        assert save_type in ["show", "png", "txt"]

        n = 2
        tran_list_fill = [1, 2, 3, 4]  # CNP
        # tran_list = [0, ]  # vv to v
        # tran_list = [0, 1, 2]  # Full Filling Gap

        # force to recalculate matrix elements
        if save_P:
            self.save_P_n(N, n, with_moire=True)

        w_list = np.arange(N_w)*w_range/N_w+w_low
        I_list = np.zeros(N_w,)
        # Row of the precomputed matrix that belongs to this D.
        D_index = int(self.D)-D_min

        read_filename = os.path.join(
            path_head, 'data', '2D_inter_band_'+str(int(n))+'.npy')
        if read_file and os.path.isfile(read_filename):
            I_list = np.load(read_filename)[D_index]
        else:
            for i, omega in enumerate(w_list):
                # tran_list_fill is the only transition list defined here;
                # the previous bare name `tran_list` raised NameError.
                I_list[i] = self.inter_band_n(
                    N, omega, n, tran_list_fill)/omega
        I_list = list(I_list)

        if sigma != 0:
            # Smooth, then rescale to the norm of the unsmoothed spectrum.
            rescale = normalizer(I_list)
            I_list = rescale(gaussian_cov(I_list, sigma=sigma))

        if save_type == "show":
            plt.figure(2)
            plt.plot(w_list, I_list)
            plt.show()

        _filename = 'D_'+str(self.D)+'_vv_to_v'
        if sigma != 0:
            _filename = _filename+"_sigma"+str(sigma)

        if save_type == "png":
            plt.figure(2)
            plt.plot(w_list, I_list)
            filename = os.path.join(
                path_head, 'image', 'spectrum', _filename+'.png')
            plt.savefig(filename)
            # Figure 2 is reused by every call; close it so the next
            # spectrum is not drawn on top of this one.
            plt.close()

        if save_type == "txt":
            filename = os.path.join(
                path_head, 'data', 'spectrum', _filename+'.txt')
            with open(filename, "w") as out:
                for i in range(N_w):
                    out.write(str(w_list[i])+" "+str(I_list[i])+"\n")

    def plot_spectrum_partial(self, N, sigma=0, percentage=50.0, save_P=False, save_type="png", read_file=False):
        """Plot or export the spectrum for a partially filled band.

        N: the BZ is divided into N*N points.
        sigma: width passed to gaussian_cov; 0 disables the smoothing.
        percentage: forwarded to cal_Ef to obtain the Fermi level.
        save_P: force recalculation via save_P_n(N, n, with_moire=True).
        save_type: "show", "png" or "txt".
        read_file: read the spectrum from '2D_inter_band_2.npy' if present.
            NOTE(review): that file name carries no percentage, so this
            branch looks independent of `percentage` - confirm before use.
        """
        assert save_type in ["show", "png", "txt"]

        n = 2
        if save_P:
            self.save_P_n(N, n, with_moire=True)

        Ef = self.cal_Ef(percentage=percentage)
        print(Ef)

        w_list = np.arange(N_w)*w_range/N_w+w_low
        I_list = np.zeros(N_w,)
        D_index = int(self.D)-D_min

        read_filename = os.path.join(
            path_head, 'data', '2D_inter_band_'+str(int(n))+'.npy')
        use_cache = os.path.isfile(read_filename) and read_file
        if use_cache:
            I_list = np.load(read_filename)[D_index]
        else:
            for w_pos, omega in enumerate(w_list):
                I_list[w_pos] = self.inter_band_partial(
                    N, omega, n, band_index=1, Ef=Ef)/omega
        I_list = list(I_list)

        if sigma != 0:
            # Smooth, then rescale to the norm of the unsmoothed spectrum.
            rescale = normalizer(I_list)
            I_list = rescale(gaussian_cov(I_list, sigma=sigma))

        if save_type not in ("png", "txt", "csv"):
            plt.figure(2)
            plt.plot(w_list, I_list)
            plt.show()

        name_parts = ['spec_partial_D_', str(self.D), '_p', str(percentage)]
        if sigma != 0:
            name_parts.extend(["_sigma", str(sigma)])
        _filename = "".join(name_parts)

        if save_type == "png":
            plt.figure(2)
            plt.plot(w_list, I_list)
            plt.ylim(0, 4500)
            plt.savefig(os.path.join(
                path_head, 'image', 'spectrum', _filename+'.png'))
            plt.close()

        if save_type == "txt":
            txt_path = os.path.join(
                path_head, 'data', 'spectrum', _filename+'.txt')
            with open(txt_path, "w") as out:
                for w_pos in range(N_w):
                    out.write(str(w_list[w_pos])+" " +
                              str(I_list[w_pos])+"\n")

    def save_spectrum_origin(self):
        """Export smoothed spectra for a fixed list of D values as a text table.

        Reads <path_head>/data/2D_inter_band.npy, picks the rows belonging to
        the 15 hard-coded D values, smooths each with sigma = 1.1 and rescales
        it to its original norm.  The table written to
        <path_head>/final_figures/spectrum_DList_sigma1.1.txt has one line per
        frequency: the frequency followed by the 15 intensities, separated
        (and terminated) by single spaces.
        """
        w_list = np.arange(N_w)*w_range/N_w+w_low

        D_list = [-87, -79, -71, -62, -53, -44, -37, -30,
                  36, 44, 52, 59, 66, 74, 83]
        sigma_list = [1.1]*len(D_list)

        I_matrix = np.load(
            os.path.join(path_head, 'data', '2D_inter_band.npy'))

        # Row 0 holds the frequency axis, rows 1.. one spectrum per D.
        output_matrix = np.zeros([len(D_list)+1, N_w], dtype=float)
        output_matrix[0] = w_list

        for row, (D, sigma) in enumerate(zip(D_list, sigma_list), 1):
            spectrum = list(I_matrix[int(D)-D_min])
            if sigma != 0:
                rescale = normalizer(spectrum)
                spectrum = rescale(gaussian_cov(spectrum, sigma=sigma))
            output_matrix[row] = spectrum

        filename = os.path.join(
            path_head, 'final_figures', 'spectrum_DList_sigma1.1.txt')
        with open(filename, "w") as out:
            for col in range(N_w):
                cells = "".join(str(x)+' ' for x in output_matrix[:, col])
                out.write(cells+'\n')

    def save_matrix_2D_inter_band(self, N=N_k, t_index=0, save_P=True):
        """Compute I(D, omega) on the full D x omega grid and save it.

        D runs over D_min .. D_min+N_D-1 (set through set_D) and omega over
        the N_w values of the spectrum axis.  Each entry is
        inter_band(N, omega, t_index)/omega.  The matrix is written to
        <path_head>/data/2D_inter_band.npy for t_index == 0 and to
        <path_head>/data/2D_inter_band_I<t_index>.npy otherwise.

        save_P: recompute the matrix elements with save_P(N) for every D.
        """
        print("Save Matrix Elements for 2D Inter Band: I{}".format(t_index))

        D_list = D_min+np.arange(N_D)
        w_list = np.arange(N_w)*w_range/N_w+w_low
        I_list = np.zeros((N_D, N_w))

        for row, D in enumerate(D_list):
            self.set_D(D)
            if save_P:
                self.save_P(N)

            I_list[row, :] = [
                self.inter_band(N, omega, t_index=t_index)/omega
                for omega in w_list]

        suffix = '' if t_index == 0 else '_I'+str(t_index)
        filename = os.path.join(
            path_head, 'data', '2D_inter_band'+suffix+'.npy')

        np.save(filename, I_list)

    def save_matrix_2D_inter_band_n(self, N=N_k, n=2, save_P=True):
        """Compute I(D, omega) from the n-band transition data and save it.

        D runs over D_min .. D_min+N_D-1 (set through set_D) and omega over
        the N_w values of the spectrum axis.  Each entry is
        inter_band_n(N, omega, n, [0, 1, 2, 3, 4])/omega.

        The matrix is written to <path_head>/data/2D_inter_band_<n>.npy.  For
        the default n = 2 this is the same '2D_inter_band_2.npy' as before;
        deriving the name from n keeps a run with another n from overwriting
        the n = 2 file that the plotting methods read.

        save_P: recompute the matrix elements with save_P_n(N, n) for every D.
        """
        D_list = D_min+np.arange(N_D)
        w_list = np.arange(N_w)*w_range/N_w+w_low
        I_list = np.zeros((N_D, N_w))

        tran_list = [0, 1, 2, 3, 4]
        for i in range(N_D):
            self.set_D(D_list[i])
            if save_P:
                self.save_P_n(N, n)

            for j in range(N_w):
                omega = w_list[j]
                I_list[i, j] = self.inter_band_n(
                    N, omega, n, tran_list)/omega

        filename = os.path.join(
            path_head, 'data', '2D_inter_band_'+str(int(n))+'.npy')

        np.save(filename, I_list)

    def plot_2D_spectrum(self, sigma=0.0, t_index=0, save_type="txt"):
        """Export the precomputed I(D, omega) matrix as a flat text table.

        The table has N_D*N_w rows and 7 columns: column 0 is D, column 1 is
        omega and column t_index+2 holds the intensity; the other columns
        stay zero.

        sigma: width passed to gaussian_cov for each fixed-D spectrum (then
            rescaled to its original norm); 0 disables the smoothing.
        t_index: selects the input file, using the names written by
            save_matrix_2D_inter_band under <path_head>/data:
            '2D_inter_band.npy' for 0, '2D_inter_band_I<t_index>.npy' else.
        save_type: only "txt" writes anything
            (<path_head>/data/2D_spectrum[_sigma<sigma>].txt).
        """
        print("Plot 2D Inter Band: I{}".format(t_index))
        D_list = D_min+np.arange(N_D)
        w_list = np.arange(N_w)*w_range/N_w+w_low

        # Row i*N_w + j describes the pair (D_list[i], w_list[j]).
        I_data = np.zeros((N_D*N_w, 7))
        I_data[:, 0] = np.repeat(D_list, N_w)
        I_data[:, 1] = np.tile(w_list, N_D)

        if t_index == 0:
            cache_name = '2D_inter_band.npy'
        else:
            cache_name = '2D_inter_band_I'+str(t_index)+'.npy'
        # Both cases live under path_head; the single-transition branch used
        # to look in the bare 'data' directory, where nothing is written.
        read_filename = os.path.join(path_head, 'data', cache_name)
        I_list = np.load(read_filename)

        for i in range(N_D):
            spec = I_list[i]
            if sigma != 0:
                rescale = normalizer(spec)
                spec = rescale(gaussian_cov(spec, sigma))
            I_data[i*N_w:(i+1)*N_w, t_index+2] = spec

        _filename = '2D_spectrum'
        if sigma != 0:
            _filename = _filename+"_sigma"+str(sigma*1.0)
        if save_type == "txt":
            filename = os.path.join(path_head, "data", _filename+".txt")
            np.savetxt(filename, I_data)

    def plot_2D_spectrum_n(self, sigma=0.0, save_type="txt"):
        """Export <path_head>/data/2D_inter_band_2.npy as a flat text table.

        The table has N_D*N_w rows and 7 columns: column 0 is D, column 1 is
        omega, column 2 the intensity; columns 3..6 stay zero.

        sigma: width passed to gaussian_cov for each fixed-D spectrum (then
            rescaled to its original norm); 0 disables the smoothing.
        save_type: only "txt" writes anything
            (<path_head>/data/2D_spec_2[_sigma<sigma>].txt).
        """
        D_list = D_min+np.arange(N_D)
        w_list = np.arange(N_w)*w_range/N_w+w_low

        # Row i*N_w + j describes the pair (D_list[i], w_list[j]).
        I_data = np.zeros((N_D*N_w, 7))
        I_data[:, 0] = np.repeat(D_list, N_w)
        I_data[:, 1] = np.tile(w_list, N_D)

        I_list = np.load(
            os.path.join(path_head, 'data', '2D_inter_band_2.npy'))

        for row in range(N_D):
            spectrum = I_list[row]
            if sigma != 0:
                rescale = normalizer(spectrum)
                spectrum = rescale(gaussian_cov(spectrum, sigma))
            I_list[row] = spectrum
            I_data[row*N_w:(row+1)*N_w, 2] = I_list[row, :N_w]

        stem = '2D_spec_2'
        if sigma != 0:
            stem += "_sigma"+str(sigma*1.0)
        if save_type == "txt":
            np.savetxt(os.path.join(path_head, "data", stem+".txt"), I_data)

    def spectrum_stat(self, sigma=0.0):
        """Tabulate peak position and half-maximum width of every spectrum.

        For each D row of <path_head>/data/2D_inter_band.npy the spectrum is
        passed through gaussian_cov(sigma) and rescaled to its original norm.
        Starting from the maximum, the nearest sample on each side that drops
        below half the maximum is located (the first / last sample is used
        when none does).

        Output rows [D, w_max, w_left, w_right, w_right - w_left] are saved
        to <path_head>/data/spectrum_stat_sigma<sigma>.txt.
        """
        D_list = D_min+np.arange(N_D)
        w_list = np.arange(N_w)*w_range/N_w+w_low
        I_list = np.load(os.path.join(path_head, 'data', '2D_inter_band.npy'))
        stat_list = np.zeros([N_D, 5])

        for row in range(N_D):
            rescale = normalizer(I_list[row])
            spec = rescale(gaussian_cov(I_list[row], sigma))

            max_I = max(spec)
            max_index = spec.index(max_I)

            # Walk outwards from the peak until the curve is below half max.
            left_index = next(
                (k for k in range(max_index-1, -1, -1)
                 if spec[k] < (max_I/2)), 0)
            right_index = next(
                (k for k in range(max_index+1, N_w)
                 if spec[k] < (max_I/2)), N_w-1)

            max_w = w_list[max_index]
            left_w = w_list[left_index]
            right_w = w_list[right_index]
            stat_list[row] = [D_list[row], max_w, left_w, right_w,
                              right_w-left_w]

        out_path = os.path.join(
            path_head, "data", "spectrum_stat_sigma"+str(sigma)+".txt")
        np.savetxt(out_path, stat_list)

    def plot_BZ_spectrum(self, N=N_k, sigma=0, percentage=0.0, save_P=False, save_type="png", filename="MBZ"):
        """Map the frequency-summed response over the N x N k-grid.

        For every k-point the energy E_list[band_index, k] (band_index = 1)
        is compared with the Fermi level from cal_Ef(percentage).  Where
        E <= Ef the transitions [1, 2, 3, 4] are used, elsewhere [0, 1, 2].
        The map value is
            sum over omega of  sum over transitions of
                (P_plus + P_minus) * L(gap - omega) / omega,
        with L(x) = 1 / (pi * (x*x/delta + delta)) and delta = 0.1, divided
        by N*N.  The map is handed to _draw_BZ_map.

        N: grid size.
        sigma: accepted for signature compatibility; not used here.
        percentage: forwarded to cal_Ef.
        save_P: force recalculation via save_P_n(N, n, with_moire=True).
        save_type: "show" or "png" - the two modes _draw_BZ_map supports.
        filename: stem of the image written by _draw_BZ_map.

        The omega and transition sums are vectorised per k-point, so values
        agree with a plain loop up to floating-point summation order.
        """
        # _draw_BZ_map asserts the same set; check here so an unsupported
        # mode fails before the expensive sum instead of after it.
        assert save_type in ["show", "png"]

        # Set range
        w_list = np.arange(N_w)*w_range/N_w+w_low

        n = 2
        band_index = n-1
        tran_list_empty = [0, 1, 2]
        tran_list_full = [1, 2, 3, 4]
        if save_P:
            self.save_P_n(N, n, with_moire=True)
        Ef = self.cal_Ef(percentage=percentage)

        print("Fermi Level: {}".format(Ef))

        P_plus_list, P_minus_list, gap_list, E_list = self.load_P_n(N, n)
        intensity = np.zeros([N, N])
        delta = 0.1  # epsilon in calculation

        for i in range(N):
            for j in range(N):
                index = i*N+j
                if E_list[band_index, index] <= Ef:
                    tran_list = tran_list_full
                else:
                    tran_list = tran_list_empty

                # shape (n_tran,): both valleys added up
                weight = (P_plus_list[tran_list, index] +
                          P_minus_list[tran_list, index])
                # shape (n_tran, N_w): gap - omega for every pair
                detuning = gap_list[tran_list, index][:, np.newaxis]-w_list
                lorentz = 1./(detuning*detuning/delta+delta) * \
                    (1./np.pi)  # joint DoS
                intensity[i, j] = np.sum(weight.dot(lorentz)/w_list)

        intensity = intensity/(N*N)
        self._draw_BZ_map(N, intensity, save_type=save_type, filename=filename)

    def _draw_BZ_map(self, N, dataset, save_type="show", filename="MBZ"):
        """Draw `dataset` as a colour map on the skewed k-grid i*G1/N + j*G2/N.

        N: grid size; the mesh has (N+1) x (N+1) corner points.
        dataset: 2D array of cell values handed to pcolormesh.
        save_type: "show" opens a window, "png" writes
            <path_head>/image/MBZ_Mapping/<filename>.png.
        filename: stem of the png file.
        """
        assert save_type in ["show", "png"]

        G1 = np.array([0.0, -1.0])
        G2 = np.array([np.sqrt(3)/2, -0.5])

        # Corner (i, j) sits at i*G1/N + j*G2/N.
        steps = np.arange(N+1)
        I, J = np.meshgrid(steps, steps, indexing="ij")
        X = 1.0*G1[0]*I/N+1.0*G2[0]*J/N
        Y = 1.0*G1[1]*I/N+1.0*G2[1]*J/N

        fig, ax = plt.subplots()
        mesh = ax.pcolormesh(X, Y, dataset, cmap=plt.get_cmap("viridis"))

        fig.colorbar(mesh, ax=ax, label="overlap(Cosine)")
        ax.set_xlabel(r"$k_x$")
        ax.set_ylabel(r"$k_y$")

        # Point labels, nudged by a small offset from the marked positions.
        text_delta = 0.025
        labels = [
            (0.2887-text_delta, -0.5-text_delta, "K'", "white"),
            (0.5773-text_delta, -1-text_delta, "K", "white"),
            (0.5*text_delta, 0-2.5*text_delta, r"$\Gamma$", "white"),
            (0.433, -0.25, "M", "black"),
        ]
        for x_pos, y_pos, text, colour in labels:
            ax.text(x_pos, y_pos, text, color=colour)

        ax.set_aspect('equal')
        fig.tight_layout()

        if save_type == "show":
            plt.show()
        else:
            plt.savefig(os.path.join(
                path_head, "image", "MBZ_Mapping", filename+".png"))

        plt.close()

    def get_gap_W(self):
        """Return [W, gap] estimated from a closed path in k-space.

        get_H is diagonalised at 10 + 10 + 15 + 10 points along four straight
        legs (each leg includes its start point and stops one step before its
        end point).  At every point the four eigenvalues with ascending index
        mid-2 .. mid+1, mid = n_layer*M*M, are kept.

        W   = max - min of the second kept band (index mid-1).
        gap = min of the third kept band (index mid) minus the max of the
              second kept band.
        """
        mid = n_layer*self.M*self.M

        def band_energies(start, end, num_k):
            # Four central eigenvalues at num_k points from start towards end.
            start_kx, start_ky = start
            end_kx, end_ky = end
            total_kx = end_kx-start_kx
            total_ky = end_ky-start_ky

            rows = np.zeros((num_k, 4))
            for step in range(num_k):
                kx = start_kx+step*total_kx/num_k
                ky = start_ky+step*total_ky/num_k

                energy, states = LA.eigh(self.get_H(kx, ky))
                rows[step, :] = np.sort(np.real(energy))[mid-2:mid+2]
            return rows

        # Corner points of the path, in the order they are visited.
        origin = (0, 0)
        corner_a = (-4*math.pi/(3*self.aM), 0)
        corner_b = (-2*math.pi/(3*self.aM),
                    2*math.pi/(math.sqrt(3)*self.aM))
        corner_c = (4*math.pi/(3*self.aM), 0)

        legs = [
            (origin, corner_a, 10),
            (corner_a, corner_b, 10),
            (corner_b, corner_c, 15),
            (corner_c, origin, 10),
        ]
        energy_list = np.concatenate(
            [band_energies(start, end, num_k) for start, end, num_k in legs],
            axis=0)

        v_energy = energy_list[:, 1]
        c_energy = energy_list[:, 2]

        W = np.max(v_energy)-np.min(v_energy)
        gap = np.min(c_energy)-np.max(v_energy)

        return [W, gap]

    def plot_gap_W(self, N):
        """Plot the band gap and the bandwidth against the swept value.

        N values -200 + i*410/N (i = 0..N-1) are assigned to self.U one at a
        time and get_gap_W() is evaluated for each.  The swept values, the
        bandwidths and the gaps are written to data/D_list.txt,
        data/W_list.txt and data/neutrality_gap.txt (relative to the working
        directory); the figure is saved to image/gap_W.png and then shown.
        """
        D_list = -200+np.arange(N)*(410./N)
        gap_list = np.zeros((N,))
        W_list = np.zeros((N,))

        for i in range(N):
            # NOTE(review): other methods change the field through set_D();
            # confirm that get_H actually reads self.U.
            self.U = D_list[i]

            print(self.U)

            W_list[i], gap_list[i] = self.get_gap_W()

        np.savetxt('data/D_list.txt', D_list)
        np.savetxt('data/W_list.txt', W_list)
        np.savetxt('data/neutrality_gap.txt', gap_list)

        plt.rc('text', usetex=True)
        plt.rc('axes', linewidth=3)
        plt.rc('font', weight='bold')
        plt.rcParams['text.latex.preamble'] = [
            r'\usepackage{sfmath} \boldmath']

        plt.figure()
        plt.plot(D_list, gap_list, linewidth=3, label=r'$\Delta(meV)$')
        plt.plot(D_list, W_list, linewidth=3, label=r'$W(meV)$')
        plt.xlabel(r'$D(meV)$', fontsize=20)

        plt.legend(fontsize=10, loc='upper right')
        plt.axhline(0, color='k')
        plt.xticks(fontsize=20)
        plt.yticks(fontsize=20)
        plt.tight_layout()

        # Save before show(): with an interactive backend the figure is gone
        # once the window is closed, and a later savefig writes a blank image.
        plt.savefig('image/gap_W.png')
        plt.show()


def angle(v1, v2):
    """Return |<v1|v2>| / (|v1| * |v2|) for two equal-length vectors.

    Both inputs are converted to complex arrays; the result is 1 for
    parallel vectors (up to a phase) and 0 for orthogonal ones.
    """
    assert len(v1) == len(v2)
    bra = np.array(v1, dtype=complex)
    ket = np.array(v2, dtype=complex)
    inner = np.conjugate(bra).dot(ket)
    norm_product = np.linalg.norm(bra)*np.linalg.norm(ket)
    return abs(inner/norm_product)


def load_wavefunc(filename):
    """Load <path_head>/data/wavefunc/<filename>.npy and return the array."""
    return np.load(
        os.path.join(path_head, "data", "wavefunc", filename+".npy"))


def compare_wavefunc(file1, file2="VM0_D-44.0"):
    """Return the N_k x N_k map of overlaps between two saved wavefunction sets.

    Entry (i, j) is angle() between band 1 of file1 and band 1 of file2 at
    k-index i*N_k + j.  Both names are resolved by load_wavefunc.
    """
    N = N_k
    reference = load_wavefunc(file2)
    candidate = load_wavefunc(file1)

    band_1, band_2 = 1, 1
    overlaps = [
        angle(candidate[:, band_1, k_index], reference[:, band_2, k_index])
        for k_index in range(N*N)]

    return np.array(overlaps, dtype=float).reshape(N, N)


def gaussian_cov(line, sigma, period=0.25):
    """Convolve `line` with a Gaussian window and return a list.

    line: sequence of samples taken at a constant step.
    sigma: standard deviation of the Gaussian, in the same units as the
        sample step; 0 returns the input unchanged (as a list).
    period: the step size between samples of `line`; default value is 0.25.

    The window spans about 5 sigma on each side (never more samples than
    `line` has) and is not normalised to unit area, so the output scale
    differs from the input scale.  The result has the same length as `line`
    (numpy "same" mode).
    """
    if sigma == 0:
        return np.array(line).tolist()

    window = min(2*int(5*sigma/period)+1, len(line))
    print("std {} {} {}".format(sigma, period, sigma/period))
    # signal.windows.gaussian is the supported location of the window
    # function; the signal.gaussian alias is deprecated in SciPy.
    gs = signal.windows.gaussian(window, std=sigma/period)
    return np.convolve(line, gs, "same").tolist()


def normalizer(line):
    """Return a function that rescales a sequence to the 2-norm of `line`.

    The returned function multiplies every element of its argument by
    norm(line)/norm(argument) and returns the result as a list.
    """
    target_norm = np.linalg.norm(line)

    def rescale(x):
        factor = target_norm/np.linalg.norm(x)
        return [value*factor for value in x]

    return rescale


if __name__ == "__main__":
    # Script entry point: run cal_W_gap for every integer D from -100 to 100
    # and report the wall-clock time.
    t_start = time.time()

    p = GBN()

    for D_value in range(-100, 101):
        p.set_D(D_value)
        p.cal_W_gap()

    # ------------------------------------------------------------------
    # Earlier experiments, kept for reference.
    # ------------------------------------------------------------------
    # p.save_dispersion(with_moire=False)

    # p.save_P_n()
    # overlap = compare_wavefunc("VM0.1_D-44.0")
    # linecut=overlap[0]
    # min_Cos=min(linecut)
    # Y=[(1-x)/(1-min_Cos) for x in linecut]
    # X=np.linspace(0,1,num=len(linecut),endpoint=False)
    # plt.plot(X,Y)
    # plt.title("VM=0.1")
    # plt.savefig("Linecut_Overlap11_VM0_VM0.1.png")

    # p._draw_BZ_map(N=N_k, dataset=overlap, save_type="png",
    #    filename="test_VM0_VM1_cmp11")
    # p.save_wavefunc()
    # wavefunc_VM0=load_wavefunc("VM0_D-44.0")
    # k_index=0
    # band_index=1
    # nn=2*n_layer
    # site_list=np.zeros([p.M*p.M,nn],dtype=complex)

    # flag=False
    # for i in range(p.M*p.M):
    #     site_list[i,:]=wavefunc_VM0[nn*i:nn*i+nn,band_index,k_index]
    #     if sum(abs(site_list[i]))>0.00001:
    #         if flag:
    #             print "Non-diagonal!"
    #         else:
    #             flag=True
    # if flag:
    #     print 1
    # else:
    #     print 0

    # p.cal_DOS()
    # p.plot_BZ_spectrum(percentage=0.0, save_P=True, filename="Tran0_VM0")
    # p._draw_2D_map(N=100,dataset=[])
    # p.plot_spectrum_n(N=N_k, sigma=0, save_P=True, read_file=False)
    # print p.cal_Ef(percentage=5)
    # for i in range(9):
    #     print p.cal_Ef(percentage=i*10+10)
    # for i in range(11):
    #     pe = i*10
    #     p.plot_spectrum_partial(
    #         N=N_k, sigma=0, percentage=pe, save_P=True, read_file=False)
    # p.plot_spectrum_partial(
    # N=N_k, sigma=0, percentage=5.0, save_P=True, read_file=False)

    elapsed_hours = (time.time()-t_start)/3600
    print(elapsed_hours, "hours")
